/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
* \file batch_matmul_utils.h
* \brief Compile-time scheduler-selection flags, plain context structs and offset helpers for batch matmul.
*/

#ifndef IMPL_MATMUL_UTILS_BATCH_MATMUL_UTILS_H
#define IMPL_MATMUL_UTILS_BATCH_MATMUL_UTILS_H

#include "matmul_config_utils.h"
#include "matmul_type_def.h"

namespace AscendC {

/*!
 * \brief Compile-time predicate over (A_TYPE, MM_CFG).
 *
 * True only when DoMatmulNorm(MM_CFG) is true and, with batchMode = ToMatmulConfig(MM_CFG).batchMode,
 * at least one of the following holds:
 *   - A_TYPE::layout is LayoutMode::NORMAL and batchMode is BATCH_LARGE_THAN_L1 or SINGLE_LARGE_THAN_L1;
 *   - A_TYPE::layout is not LayoutMode::NONE and batchMode is BATCH_LESS_THAN_L1.
 */
template <typename A_TYPE, const auto& MM_CFG>
constexpr bool IsBmmEnableScheduler = DoMatmulNorm(MM_CFG) &&
    ((A_TYPE::layout == LayoutMode::NORMAL &&
      (ToMatmulConfig(MM_CFG).batchMode == BatchMode::BATCH_LARGE_THAN_L1 ||
       ToMatmulConfig(MM_CFG).batchMode == BatchMode::SINGLE_LARGE_THAN_L1)) ||
     (A_TYPE::layout != LayoutMode::NONE &&
      ToMatmulConfig(MM_CFG).batchMode == BatchMode::BATCH_LESS_THAN_L1));

/*!
 * \brief Compile-time predicate over (A_TYPE, MM_CFG).
 *
 * True only when all three conditions hold:
 *   - DoMatmulNorm(MM_CFG) is true;
 *   - A_TYPE::layout is not LayoutMode::NONE;
 *   - ToMatmulConfig(MM_CFG).batchMode is not BatchMode::SINGLE_LARGE_THAN_L1.
 */
template <typename A_TYPE, const auto& MM_CFG>
constexpr bool IsBmmBatchScheduler =
    DoMatmulNorm(MM_CFG) &&
    A_TYPE::layout != LayoutMode::NONE &&
    ToMatmulConfig(MM_CFG).batchMode != BatchMode::SINGLE_LARGE_THAN_L1;

/*!
 * \brief Compile-time predicate over (A_TYPE, MM_CFG).
 *
 * True only when all three conditions hold:
 *   - DoMatmulNorm(MM_CFG) is true;
 *   - A_TYPE::layout is LayoutMode::NORMAL;
 *   - ToMatmulConfig(MM_CFG).batchMode is BatchMode::SINGLE_LARGE_THAN_L1.
 */
template <typename A_TYPE, const auto& MM_CFG>
constexpr bool IsBmmSingleScheduler =
    DoMatmulNorm(MM_CFG) &&
    A_TYPE::layout == LayoutMode::NORMAL &&
    ToMatmulConfig(MM_CFG).batchMode == BatchMode::SINGLE_LARGE_THAN_L1;

// Plain data holder with one (mod, divisor, align) triple for each of A, B and bias, plus a bias flag.
// Only setBiasFlag has a default member initializer; the int32_t members have none, so a
// default-initialized instance leaves them indeterminate until the owner assigns them.
// NOTE(review): how the triples are combined into an offset is not visible in this header --
// confirm the exact formula against the code that consumes this struct.
struct BatchOffsetInfo
{
    // Values for operand A.
    int32_t modA;
    int32_t divisorA;
    int32_t alignA;
    // Values for operand B.
    int32_t modB;
    int32_t divisorB;
    int32_t alignB;
    // Values for the bias operand.
    int32_t modBias;
    int32_t divisorBias;
    int32_t alignBias;
    // Defaults to false.
    bool setBiasFlag {false};
};

// Plain aggregate of five 16-bit length/offset fields. No member has a default initializer.
// NOTE(review): judging by the names these describe L1/L0 tiling along a non-K axis and the K axis;
// units and valid ranges are not established in this header -- confirm against the users.
struct SplitParams
{
    int16_t axisL1Len;
    int16_t kAxisL1Len;
    int16_t axisL1Offset;
    int16_t kAxisL1Offset;
    int16_t axisL0Len;
};

// Plain aggregate bundling three int32_t offsets (A, B, bias), a reduce-G count/flag pair and
// one SplitParams for each of A and B. No member has a default initializer.
// NOTE(review): field meanings are inferred from names only; this header does not show how the
// scheduler reads or writes them -- confirm against the scheduler implementation.
struct BatchSchedulerContext
{
    int32_t offsetA;
    int32_t offsetB;
    int32_t offsetBias;
    uint32_t reduceGNum;
    bool isReduceG;
    SplitParams aL0Params;
    SplitParams bL0Params;
};

// Three 64-bit offsets, one each for A, B and C, all defaulting to 0.
// NOTE(review): whether these are element or byte offsets is not shown in this header -- confirm.
struct BmmOffset {
    uint64_t offA = 0;
    uint64_t offB = 0;
    uint64_t offC = 0;
};

/*!
 * \brief Integer division of num1 by num2, rounded up.
 *
 * It is invoked by the matmulV3 operator and cannot be removed at present.
 *
 * \param num1 dividend
 * \param num2 divisor; asserted to be greater than 0
 * \return ceil(num1 / num2)
 */
__aicore__ inline uint16_t CeilDiv(uint16_t num1, uint16_t num2)
{
    ASSERT(num2 > 0);
    // Work in 32-bit signed arithmetic (what integral promotion does anyway) so that
    // num1 + num2 - 1 cannot wrap at 16 bits.
    const int32_t divisor = static_cast<int32_t>(num2);
    const int32_t roundedUpDividend = static_cast<int32_t>(num1) + divisor - 1;
    return static_cast<uint16_t>(roundedUpDividend / divisor);
}

/*!
 * \brief Rounds num1 up to a multiple of num2.
 *
 * It is invoked by the matmulV3 operator and cannot be removed at present.
 *
 * \param num1 value to align
 * \param num2 alignment unit; asserted to be greater than 0
 * \return CeilDiv(num1, num2) * num2, converted to uint16_t
 *
 * NOTE(review): the return type is 16 bits wide, so an aligned value above 65535 is truncated
 * (e.g. num1 = 65535, num2 = 16 yields 0). Callers holding wider values should not route them
 * through this helper.
 */
__aicore__ inline uint16_t CeilAlign(uint16_t num1, uint16_t num2)
{
    ASSERT(num2 > 0);
    const uint16_t unitCount = CeilDiv(num1, num2);
    return static_cast<uint16_t>(unitCount * num2);
}

/*!
 * \brief Computes the byte offset associated with batch-loop iteration loopIdx for INPUT_TYPE's layout.
 *
 * With D' = layoutInfoD (rounded up to a multiple of AscendCUtils::GetC0Count(sizeof(T)) when
 * INPUT_TYPE::format is CubeFormat::ND_ALIGN), start = loopIdx * batchValue and T = INPUT_TYPE::T:
 *   - BNGS1S2 / NORMAL: D' * layoutInfoS * start * sizeof(T)
 *   - SBNGD:            D' * start * sizeof(T)
 *   - BSNGD:            ((start / NG) * D' * layoutInfoS * NG + (start % NG) * D') * sizeof(T),
 *                       where NG = layoutInfoN * layoutInfoG
 *   - any other layout: 0
 *
 * All arithmetic is carried out in 64 bits so that large shapes do not wrap at 32 bits before
 * being widened to the uint64_t result, and D' is aligned without narrowing to 16 bits.
 *
 * \tparam INPUT_TYPE type exposing static members `format` and `layout` and a nested type `T`
 * \param batchValue  multiplied by loopIdx to obtain the starting batch index
 * \param loopIdx     index of the current batch-loop iteration
 * \param layoutInfoN N factor; only read for the BSNGD layout
 * \param layoutInfoG G factor; only read for the BSNGD layout
 * \param layoutInfoD D factor
 * \param layoutInfoS S factor; read for BNGS1S2, NORMAL and BSNGD
 * \return offset in bytes
 *
 * NOTE(review): for BSNGD, layoutInfoN * layoutInfoG must be non-zero; this is not checked here.
 */
template <class INPUT_TYPE>
__aicore__ inline uint64_t CalcNBatchoffset(uint32_t batchValue, uint32_t loopIdx, uint32_t layoutInfoN, uint32_t layoutInfoG, uint32_t layoutInfoD, uint32_t layoutInfoS)
{
    using ElemT = typename INPUT_TYPE::T;
    constexpr uint64_t elemBytes = sizeof(ElemT);

    // Widen once up front; every product below is then evaluated in 64 bits.
    uint64_t alignedSingleCoreN = layoutInfoD;
    if constexpr (INPUT_TYPE::format == CubeFormat::ND_ALIGN) {
        // Align locally instead of via the uint16_t CeilAlign helper, which would narrow layoutInfoD.
        const uint64_t c0Count = static_cast<uint64_t>(AscendCUtils::GetC0Count(sizeof(ElemT)));
        ASSERT(c0Count > 0);
        alignedSingleCoreN = (alignedSingleCoreN + c0Count - 1) / c0Count * c0Count;
    }

    const uint64_t batchStart = static_cast<uint64_t>(loopIdx) * batchValue;

    // Zero-initialized so that layouts handled by none of the branches return a defined value.
    uint64_t offset = 0;
    if constexpr (INPUT_TYPE::layout == LayoutMode::BNGS1S2 || INPUT_TYPE::layout == LayoutMode::NORMAL) {
        offset = alignedSingleCoreN * layoutInfoS * batchStart * elemBytes;
    } else if constexpr (INPUT_TYPE::layout == LayoutMode::SBNGD) {
        offset = alignedSingleCoreN * batchStart * elemBytes;
    } else if constexpr (INPUT_TYPE::layout == LayoutMode::BSNGD) {
        const uint64_t ngCount = static_cast<uint64_t>(layoutInfoN) * layoutInfoG;
        const uint64_t layoutBIdx = batchStart / ngCount;
        const uint64_t layoutNGOff = batchStart % ngCount;
        offset = (layoutBIdx * alignedSingleCoreN * layoutInfoS * ngCount + layoutNGOff * alignedSingleCoreN) * elemBytes;
    }
    return offset;
}

/*!
 * \brief Derives the batch count for C from the batch counts and G factors of A and B.
 *
 * Starts from max(batchA, batchB). When cLayoutInfoG is 1 while at least one of aLayoutInfoG /
 * bLayoutInfoG is not 1, that maximum is divided by max(aLayoutInfoG, bLayoutInfoG)
 * (integer division); otherwise it is returned unchanged.
 *
 * \param batchA       batch count of A
 * \param batchB       batch count of B
 * \param aLayoutInfoG G factor of A
 * \param bLayoutInfoG G factor of B
 * \param cLayoutInfoG G factor of C
 * \return batch count for C
 *
 * NOTE(review): if cLayoutInfoG is 1 and both aLayoutInfoG and bLayoutInfoG are 0 the divisor
 * is 0; callers are presumably expected never to pass that combination -- confirm.
 */
__aicore__ inline uint64_t GetBatchCNum(uint32_t batchA, uint32_t batchB, uint32_t aLayoutInfoG, uint32_t bLayoutInfoG, uint32_t cLayoutInfoG)
{
    const uint32_t largerBatch = (batchB < batchA) ? batchA : batchB;

    const bool outputHasUnitG = (cLayoutInfoG == 1);
    const bool inputsHaveUnitG = (aLayoutInfoG == 1) && (bLayoutInfoG == 1);
    if (!outputHasUnitG || inputsHaveUnitG) {
        return largerBatch;
    }

    const uint32_t largerInputG = (aLayoutInfoG < bLayoutInfoG) ? bLayoutInfoG : aLayoutInfoG;
    return largerBatch / largerInputG;
}
} // namespace AscendC
#endif // IMPL_MATMUL_UTILS_BATCH_MATMUL_UTILS_H